import torch.nn.functional as F
import yaml
import torch
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping
from torchvision import datasets, transforms
from torch import nn
import wandb
from pytorch_lightning.callbacks import Callback
import glob
import os
import numpy as np
import cv2

def unnormalize(image):
    """Min-max rescale *image* to the range [0, 1] using its global min and max.

    Args:
        image (torch.Tensor): tensor of any shape.

    Returns:
        torch.Tensor: ``(image - min) / (max - min)``. If every element is
        equal (max == min), an all-zero tensor is returned instead of NaN.
    """
    min_val = image.min()
    max_val = image.max()
    value_range = max_val - min_val
    if value_range == 0:
        # Constant input: avoid 0/0 = NaN; (image - min) is all zeros anyway,
        # so dividing by one maps every element to 0.
        value_range = torch.ones_like(value_range)
    return (image - min_val) / value_range
def extract_run_name(run_name_pattern):
    """Return the file-name part of *run_name_pattern* up to its first dot.

    Directories are stripped and everything from the first ``.`` onward is
    dropped, e.g. ``"runs/exp1.v2.ckpt"`` -> ``"exp1"``.
    """
    file_name = os.path.basename(run_name_pattern)
    stem, _sep, _rest = file_name.partition('.')
    return stem

def find_first_checkpoint(run_name_pattern):
    """Find a checkpoint whose file name starts with the run name of *run_name_pattern*.

    The run name (see ``extract_run_name``: file name up to its first dot)
    gets a ``*`` wildcard appended, so ``ckpts/exp.ckpt`` also matches e.g.
    ``ckpts/exp-epoch=3.ckpt``. Only the file-name component is rewritten;
    the directory part is used verbatim.

    Returns:
        The lexicographically smallest matching path, or None if nothing
        matches. Sorting makes the result deterministic (``glob`` order is
        filesystem-dependent).
    """
    directory, file_name = os.path.split(run_name_pattern)
    run_name = extract_run_name(run_name_pattern)
    # Insert the wildcard right after the run-name prefix of the file name.
    # A str.replace over the whole path would also rewrite directories or
    # extensions that happen to contain the same text (and, for an empty run
    # name, would put '*' between every character).
    wildcard_name = run_name + '*' + file_name[len(run_name):]
    checkpoint_list = sorted(glob.glob(os.path.join(directory, wildcard_name)))
    return checkpoint_list[0] if checkpoint_list else None

def load_checkpoint(model, path, map_location=None):
    """Best-effort load of a state dict from *path* into *model*.

    Args:
        model: object whose ``load_state_dict`` receives the loaded checkpoint
            (e.g. an ``nn.Module``).
        path: file passed to ``torch.load``.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` lets a
            GPU-saved checkpoint load on a CPU-only machine. None keeps
            ``torch.load``'s default behavior.

    Returns:
        ``path`` on success, ``None`` on any failure. Failures are reported
        via ``warnings.warn`` rather than being silently swallowed.
    """
    import warnings

    try:
        state_dict = torch.load(path, map_location=map_location)
        model.load_state_dict(state_dict)
    except FileNotFoundError:
        warnings.warn(f"Checkpoint not found: {path!r}")
        return None
    except Exception as exc:  # deliberately best-effort: any failure -> None
        warnings.warn(f"Failed to load checkpoint {path!r}: {exc!r}")
        return None
    return path

def divide_last_two_elements(input_tuple, n):
    """Floor-divide the final two entries of a sequence by *n*.

    Leading entries are copied unchanged; the result is always a tuple.
    Raises IndexError if the input has fewer than two entries.
    """
    values = tuple(input_tuple)
    second_last, last = values[-2], values[-1]
    return values[:-2] + (second_last // n, last // n)

def get_params_list(config, n, x_sigma=0.1, l_sigma=0.1, *,
                    num_classes=1000, num_targets=100, device='cuda'):
    """Create trainable inputs, per-row lambdas and random target classes.

    Args:
        config (dict): must provide ``'ORIGINAL_SHAPE'`` (its last two entries
            are floor-divided by *n*) and ``'NUM_SAMPLES'``.
        n (int): downscale factor for the last two entries of the shape.
        x_sigma (float): std of the Gaussian noise around 0.5 used to
            initialise the trainable inputs.
        l_sigma (float): std of the Gaussian used to initialise lambdas.
        num_classes (int): size of the class range targets are drawn from
            (previously hard-coded to 1000).
        num_targets (int): number of distinct target classes to draw
            (previously hard-coded to 100).
        device: device the target indices are moved to (previously always
            ``.cuda()``); the other returned tensors stay on the default device.

    Returns:
        tuple: ``(trainable_params, trainable_lambda, target, zero_tensor)``
        where, with ``rows = num_targets * config['NUM_SAMPLES']``:
        ``trainable_params`` is an ``nn.Parameter`` of shape ``(rows, *shape)``,
        ``trainable_lambda`` an ``nn.Parameter`` of shape ``(rows, 1)``,
        ``target`` holds ``num_targets`` distinct indices in ``[0, num_classes)``,
        and ``zero_tensor`` is ``zeros(rows, num_targets)``.
    """
    # NOTE(review): config['NUM_CLASSES'] used to be read into an unused local;
    # the class range is controlled by num_classes instead.
    shape = divide_last_two_elements(config['ORIGINAL_SHAPE'], n)
    num_rows = num_targets * config['NUM_SAMPLES']

    # nn.Parameter already sets requires_grad=True; the RNG call order below
    # (randn, normal, randperm) is kept so seeded runs reproduce.
    trainable_params = nn.Parameter(0.5 + torch.randn(num_rows, *shape) * x_sigma)
    # |N(1, l_sigma)| + 0.05 keeps every initial lambda strictly positive.
    trainable_lambda = nn.Parameter(
        torch.abs(torch.normal(mean=1, std=l_sigma, size=(num_rows, 1))) + 5e-2
    )
    target = torch.randperm(num_classes)[:num_targets].to(device)
    zero_tensor = torch.zeros((num_rows, num_targets))
    print(trainable_params.shape, trainable_lambda.shape, target.shape, zero_tensor.shape)
    return (trainable_params, trainable_lambda, target, zero_tensor)


def upscale_tensor(tensor, n=2):
    """Upsample the last two dimensions of *tensor* by an integer factor *n*.

    Uses linear interpolation (``trilinear`` for 5-D input, ``bilinear``
    otherwise) via ``F.interpolate``, which is differentiable, so the
    autograd graph is preserved.

    Args:
        tensor (torch.Tensor): 5-D input (the third-from-last axis is left
            unscaled) or 4-D input (required by ``bilinear`` mode).
        n (int): scale factor applied to the last two dimensions.

    Returns:
        torch.Tensor: the upsampled tensor.
    """
    if tensor.dim() == 5:
        # Scale factor 1 on the third-from-last axis keeps its size unchanged.
        return F.interpolate(tensor, scale_factor=(1, n, n), mode='trilinear')
    return F.interpolate(tensor, scale_factor=(n, n), mode='bilinear')


def downgrade_resol(tensor):
    """Reduce the effective resolution of *tensor* while keeping its shape.

    The tensor is average-pooled by ``downsample`` (roughly halving the last
    two dimensions) and then resized back to its original size along all
    dimensions after the first two, using ``trilinear`` interpolation for
    5-D input and ``bilinear`` otherwise.

    Returns:
        torch.Tensor: same shape as *tensor*, with fine detail removed.
    """
    original_size = tensor.shape[2:]
    mode = 'trilinear' if len(tensor.shape) == 5 else 'bilinear'
    downscaled = downsample(tensor)
    return F.interpolate(downscaled, size=original_size, mode=mode)
def downsample(tensor):
    """Average-pool the last two dimensions of *tensor* by a factor of 2.

    Args:
        tensor (torch.Tensor): 5-D input (kernel ``(1, 2, 2)``, so the
            third-from-last axis is untouched) or 4-D input (2x2 kernel).

    Returns:
        torch.Tensor: the pooled tensor; odd sizes are floored (no padding).

    Raises:
        ValueError: if *tensor* is not 4-D or 5-D.
    """
    if tensor.dim() == 5:
        return F.avg_pool3d(tensor, kernel_size=(1, 2, 2))
    if tensor.dim() == 4:
        return F.avg_pool2d(tensor, kernel_size=2)
    # Fail loudly here instead of returning None, which would only surface
    # later as an obscure error inside the caller.
    raise ValueError(f"downsample expects a 4-D or 5-D tensor, got {tensor.dim()}-D")

from torchvision.transforms.functional import to_pil_image

def read_yaml(file_path):
    """Parse the YAML file at *file_path* and return the resulting object.

    A malformed document is reported by printing the ``yaml.YAMLError`` and
    returning None; other errors (e.g. a missing file) propagate.
    """
    with open(file_path, 'r') as handle:
        try:
            # NOTE(review): FullLoader accepts some Python-specific tags; switch
            # to yaml.safe_load if config files can come from untrusted sources
            # (check first that no config relies on e.g. !!python/tuple).
            parsed = yaml.load(handle, Loader=yaml.FullLoader)
        except yaml.YAMLError as err:
            print(err)
            parsed = None
    return parsed

def only_right_loss(output, target, samples_lambda, is_celeba = False):
    """Lambda-weighted NLL loss restricted to misclassified samples.

    Args:
        output (torch.Tensor): ``(batch, classes)`` scores passed to
            ``F.nll_loss`` — presumably log-probabilities (TODO confirm at
            the caller).
        target (torch.Tensor): ``(batch,)`` class indices.
        samples_lambda (torch.Tensor): per-sample weights; negative values are
            clipped to 0 by ReLU and a 1e-9 floor is added.
        is_celeba (bool): unused; kept for call-site compatibility.

    Returns:
        torch.Tensor: the weighted per-sample losses of the misclassified
        samples, or a 0-d zero tensor on ``output``'s device when every
        prediction already matches ``target``.
    """
    # argmax without keepdim keeps shape (batch,) even for batch == 1, so the
    # boolean mask always has the same rank as the per-sample loss (a
    # squeeze() would turn a single-sample batch into a 0-d mask, and indexing
    # with it adds a spurious dimension).
    pred = output.argmax(dim=1)
    incorrect_mask = pred != target
    if not incorrect_mask.any():
        # Build the zero on the model output's device instead of forcing CUDA.
        return torch.tensor(0.0, device=output.device)

    ce_loss = F.nll_loss(output, target, reduction='none')
    # NOTE(review): ce_loss is (batch,); if samples_lambda is (batch, 1) this
    # product broadcasts to (batch, batch) — confirm what the caller passes.
    ce_loss = ce_loss * (torch.relu(samples_lambda) + 1e-9)
    return ce_loss[incorrect_mask]
